#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import logging
from typing import Dict, List, Optional, Tuple

import torch
from pytext import resources
from pytext.common.constants import Stage, TORCH_VERSION
from pytext.config import ConfigBase
from pytext.data.roberta_tensorizer import (
    RoBERTaTensorizer,
    RoBERTaTokenLevelTensorizer,
)
from pytext.data.tensorizers import (
    FloatListTensorizer,
    LabelTensorizer,
    NumericLabelTensorizer,
    Tensorizer,
)
from pytext.models.bert_classification_models import NewBertModel
from pytext.models.bert_regression_model import NewBertRegressionModel
from pytext.models.decoders.mlp_decoder import MLPDecoder
from pytext.models.model import BaseModel
from pytext.models.module import create_module, Module
from pytext.models.output_layers import WordTaggingOutputLayer
from pytext.models.representations.transformer import (
    MultiheadLinearAttention,
    PassthroughEncoder,
    PassthroughTransformer,
    PostEncoder,
    QuantizedMultiheadLinearAttention,
    SELFIETransformer,
    SentenceEncoder,
)
from pytext.models.representations.transformer_sentence_encoder_base import (
    PoolingMethod,
    TransformerSentenceEncoderBase,
)
from pytext.models.utils import normalize_embeddings
from pytext.torchscript.module import (
    ScriptPyTextEmbeddingModuleIndex,
    ScriptPyTextModule,
    ScriptPyTextModuleWithDense,
)
from pytext.utils.file_io import PathManager
from pytext.utils.usage import log_class_usage
from pytorch.text.fb.nn.modules.multihead_attention import MultiheadSelfAttention
from pytorch.text.fb.nn.modules.transformer import (
    DEFAULT_MAX_SEQUENCE_LENGTH,
    Transformer,
    TransformerLayer,
)
from torch import nn
from torch.serialization import default_restore_location

from .r3f_models import R3FConfigOptions, R3FPyTextMixin

if TORCH_VERSION >= (1, 11):
    from torch.ao.quantization import convert_jit, get_default_qconfig, prepare_jit
else:
    from torch.quantization import convert_jit, get_default_qconfig, prepare_jit

logger = logging.getLogger(name=__name__)


def init_params(module):
    """Initialize the RoBERTa weights for pre-training from scratch."""

    if isinstance(module, torch.nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, torch.nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
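# Note: ``init_params`` is applied recursively via ``nn.Module.apply`` in
# ``RoBERTaEncoder.__init__`` below (``self.apply(init_params)``).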


class RoBERTaEncoderBase(TransformerSentenceEncoderBase):
    __EXPANSIBLE__ = True

    class Config(TransformerSentenceEncoderBase.Config):
        pass

    def _encoder(self, inputs, *args):
        # NewBertModel expects the output as a tuple and grabs the first element
        tokens, _, _, _ = inputs
        full_representation = (
            self.encoder(tokens, args) if len(args) > 0 else self.encoder(tokens)
        )
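        # sentence_rep takes the last encoded layer's representation of the first
        # token (RoBERTa's <s>, the [CLS]-equivalent position).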
        sentence_rep = full_representation[-1][:, 0, :]
        return full_representation, sentence_rep


class RoBERTaEncoderJit(RoBERTaEncoderBase):
    """A TorchScript RoBERTa implementation"""

    class Config(RoBERTaEncoderBase.Config):
        pretrained_encoder: Module.Config = Module.Config(
            load_path=resources.roberta.PUBLIC
        )

    def __init__(self, config: Config, output_encoded_layers: bool, **kwarg) -> None:
        config.pretrained_encoder.load_path = (
            resources.roberta.RESOURCE_MAP[config.pretrained_encoder.load_path]
            if config.pretrained_encoder.load_path in resources.roberta.RESOURCE_MAP
            else config.pretrained_encoder.load_path
        )
        super().__init__(config, output_encoded_layers=output_encoded_layers)
        assert config.pretrained_encoder.load_path, "Load path cannot be empty."
        self.encoder = create_module(config.pretrained_encoder)
        self.representation_dim = self.encoder.encoder.token_embedding.weight.size(-1)
        log_class_usage(__class__)

    def _embedding(self):
        # used to tie weights in MaskedLM model
        return self.encoder.encoder.token_embedding


class RoBERTaEncoder(RoBERTaEncoderBase):
    """A PyTorch RoBERTa implementation"""

    class Config(RoBERTaEncoderBase.Config):
        embedding_dim: int = 768
        vocab_size: int = 50265
        num_encoder_layers: int = 12
        num_attention_heads: int = 12
        model_path: str = (
            "manifold://pytext_training/tree/static/models/roberta_base_torch.pt"
        )
        # How the state dict is loaded depends on whether the model was
        # previously fine-tuned in PyText. If it was, we don't need to translate
        # the state dict and can load it directly.
        is_finetuned: bool = False
        # If loading RoBERTaEncoder from a model where RoBERTaEncoder is just a
        # submodel, provide the submodel name here
        load_partial_model: Optional[str] = None
        max_seq_len: int = DEFAULT_MAX_SEQUENCE_LENGTH
        # Fine-tune bias parameters only (https://nlp.biu.ac.il/~yogo/bitfit.pdf)
        use_bias_finetuning: bool = False
        # Linformer hyperparameters
        use_linformer_encoder: bool = False
        linformer_compressed_ratio: int = 4
        linformer_quantize: bool = False
        export_encoder: bool = False
        variable_size_embedding: bool = True
        use_selfie_encoder: bool = False
        transformer_layer_to_keep: Optional[int] = None
        attention_heads_to_keep_per_layer: Optional[int] = None
        attention_heads_to_keep_per_layer_list: Optional[List[int]] = None
        prune_before_load: Optional[bool] = False
        scaling: Optional[float] = None
        normalize_before: bool = False
        skip_token_embed: bool = False
        use_mixup: bool = False

    def __init__(  # noqa C901
        self,
        config: Config,
        output_encoded_layers: bool,
        token_embedding: Optional[nn.Embedding] = None,
        **kwarg,
    ) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)

        # map to the real model_path
        config.model_path = (
            resources.roberta.RESOURCE_MAP[config.model_path]
            if config.model_path in resources.roberta.RESOURCE_MAP
            else config.model_path
        )
        # When using linear (Linformer) multihead attention, create a single
        # compress layer shared across all transformer layers.
        if config.use_linformer_encoder:
            compress_layer = nn.Linear(
                config.max_seq_len - 2,
                (config.max_seq_len - 2) // config.linformer_compressed_ratio,
            )
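            # Linformer projects keys/values along the sequence dimension: this
            # layer maps (max_seq_len - 2) positions down to
            # (max_seq_len - 2) // linformer_compressed_ratio.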

        self.use_selfie_encoder = config.use_selfie_encoder
        self.skip_token_embed = config.skip_token_embed

        if config.use_linformer_encoder:
            if config.linformer_quantize:
                layers = [
                    TransformerLayer(
                        embedding_dim=config.embedding_dim,
                        attention=QuantizedMultiheadLinearAttention(
                            embed_dim=config.embedding_dim,
                            num_heads=config.num_attention_heads,
                            compress_layer=compress_layer,
                        ),
                    )
                    for _ in range(config.num_encoder_layers)
                ]
            else:
                layers = [
                    TransformerLayer(
                        embedding_dim=config.embedding_dim,
                        attention=MultiheadLinearAttention(
                            embed_dim=config.embedding_dim,
                            num_heads=config.num_attention_heads,
                            compress_layer=compress_layer,
                        ),
                    )
                    for _ in range(config.num_encoder_layers)
                ]
        else:
            layers = [
                TransformerLayer(
                    embedding_dim=config.embedding_dim,
                    attention=MultiheadSelfAttention(
                        embed_dim=config.embedding_dim,
                        num_heads=config.num_attention_heads,
                        scaling=config.scaling,
                    ),
                    normalize_before=config.normalize_before,
                )
                for _ in range(config.num_encoder_layers)
            ]
        if not config.skip_token_embed:
            self.encoder = (
                SentenceEncoder(
                    transformer=Transformer(
                        vocab_size=config.vocab_size,
                        embedding_dim=config.embedding_dim,
                        layers=layers,
                        max_seq_len=config.max_seq_len,
                        normalize_before=config.normalize_before,
                        token_embedding=token_embedding,
                        use_mixup=config.use_mixup,
                    )
                )
                if not self.use_selfie_encoder
                else PostEncoder(
                    transformer=SELFIETransformer(
                        vocab_size=config.vocab_size,
                        embedding_dim=config.embedding_dim,
                        layers=layers,
                        max_seq_len=config.max_seq_len,
                    )
                )
            )
        else:
            self.encoder = PassthroughEncoder(
                transformer=PassthroughTransformer(
                    vocab_size=config.vocab_size,
                    embedding_dim=config.embedding_dim,
                    layers=layers,
                    max_seq_len=config.max_seq_len,
                    normalize_before=config.normalize_before,
                )
            )
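        # Randomly (re-)initialize all weights; pretrained weights are loaded
        # over them below when config.model_path is set.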
        self.apply(init_params)

        if config.prune_before_load:
            self._prune_transformer_layers_and_heads(config)

        if config.model_path:
            with PathManager.open(config.model_path, "rb") as f:
                roberta_state = torch.load(
                    f, map_location=lambda s, l: default_restore_location(s, "cpu")
                )
            # If the model was previously loaded into PyText and fine-tuned, the
            # special state-dict translation is unnecessary and the state dict
            # can be loaded directly.
            if not config.is_finetuned:
                self.encoder.load_roberta_state_dict(roberta_state["model"])
            elif config.load_partial_model is not None:
                roberta_state = {
                    k.replace(config.load_partial_model + ".", ""): v
                    for k, v in roberta_state["model_state"].items()
                    if k.startswith(config.load_partial_model)
                }
                self.load_state_dict(roberta_state)
            else:
                self.load_state_dict(roberta_state)

        if config.use_bias_finetuning:
            for (n, p) in self.encoder.named_parameters():
                # "encoder.transformer.layers.0.attention.input_projection.weight" -> false
                # "encoder.transformer.layers.0.attention.input_projection.bias" -> true
                if n.split(".")[-1] != "bias":
                    p.requires_grad_(False)

        if not config.prune_before_load:
            self._prune_transformer_layers_and_heads(config)

        self.export_encoder = config.export_encoder
        self.variable_size_embedding = config.variable_size_embedding
        self.use_linformer_encoder = config.use_linformer_encoder
        log_class_usage(__class__)

    def _prune_transformer_layers_and_heads(self, config: Config):
        if config.transformer_layer_to_keep is None:
            config.transformer_layer_to_keep = config.num_encoder_layers

        if config.transformer_layer_to_keep != config.num_encoder_layers:
            logger.info(f"prune the layers to {config.transformer_layer_to_keep}")
            self.encoder.transformer.layers = self.encoder.transformer.layers[
                0 : config.transformer_layer_to_keep
            ]

        if (
            config.attention_heads_to_keep_per_layer is not None
            or config.attention_heads_to_keep_per_layer_list is not None
        ):
            attention_heads_to_keep_per_layer_list = (
                config.attention_heads_to_keep_per_layer_list
                if config.attention_heads_to_keep_per_layer_list is not None
                else [config.attention_heads_to_keep_per_layer]
                * config.transformer_layer_to_keep
            )

            assert (
                len(attention_heads_to_keep_per_layer_list)
                == config.transformer_layer_to_keep
            )
            heads_to_prune = {
                i: list(
                    range(
                        config.num_attention_heads
                        - attention_heads_to_keep_per_layer_list[i]
                    )
                )
                for i in range(config.transformer_layer_to_keep)
            }
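            # e.g. with num_attention_heads=12 and 8 heads kept per layer,
            # heads_to_prune == {0: [0, 1, 2, 3], 1: [0, 1, 2, 3], ...}; the
            # lowest-numbered heads are always the ones pruned.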
            for layer_index, heads in heads_to_prune.items():
                logger.info(f"prune layer {layer_index} heads by {len(heads)}")
                if config.use_linformer_encoder or config.use_selfie_encoder:
                    self.encoder.transformer.layers[
                        layer_index
                    ].attention.prune_multi_linear_heads(heads=heads)
                else:
                    self.encoder.transformer.layers[
                        layer_index
                    ].attention.prune_multi_heads(heads=heads)

    def _embedding(self):
        # used to tie weights in MaskedLM model
        return self.encoder.transformer.token_embedding

    def forward(
        self, input_tuple: Tuple[torch.Tensor, ...], *args
    ) -> Tuple[torch.Tensor, ...]:
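        # input_tuple is arranged by the model's arrange_model_inputs and is
        # typically ordered (tokens, pad_mask, segment_labels, positions);
        # pad_mask (index 1) is used below for pooling.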

        encoded_layers, pooled_output = (
            self._encoder(input_tuple, args[0])
            if self.use_selfie_encoder or self.skip_token_embed
            else self._encoder(input_tuple)
        )

        pad_mask = input_tuple[1]

        if self.pooling != PoolingMethod.CLS_TOKEN:
            pooled_output = self._pool_encoded_layers(encoded_layers, pad_mask)

        if self.projection:
            pooled_output = self.projection(pooled_output).tanh()

        if pooled_output is not None:
            pooled_output = self.output_dropout(pooled_output)
            if self.normalize_output_rep:
                pooled_output = normalize_embeddings(pooled_output)

        output = []
        if self.output_encoded_layers:
            output.append(encoded_layers)
        if self.pooling != PoolingMethod.NO_POOL:
            output.append(pooled_output)
        return tuple(output)


class RoBERTa(NewBertModel):
    class Config(NewBertModel.Config):
        class InputConfig(ConfigBase):
            tokens: RoBERTaTensorizer.Config = RoBERTaTensorizer.Config()
            dense: Optional[FloatListTensorizer.Config] = None
            labels: LabelTensorizer.Config = LabelTensorizer.Config()

        inputs: InputConfig = InputConfig()
        encoder: RoBERTaEncoderBase.Config = RoBERTaEncoderJit.Config()

    def trace(self, inputs):
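        # When export_encoder is set, only the encoder submodule is traced, so
        # the exported artifact emits embeddings rather than classifier outputs
        # (see the ScriptPyTextEmbeddingModuleIndex wrapping in torchscriptify).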
        if self.encoder.export_encoder:
            return torch.jit.trace(self.encoder, inputs)
        else:
            return torch.jit.trace(self, inputs)

    def torchscriptify(self, tensorizers, traced_model):
        """Using the traced model, create a ScriptModule which has a nicer API that
        includes generating tensors from simple data types, and returns classified
        values according to the output layer (eg. as a dict mapping class name to score)
        """
        script_tensorizer = tensorizers["tokens"].torchscriptify()
        if self.encoder.export_encoder:
            return ScriptPyTextEmbeddingModuleIndex(
                traced_model, script_tensorizer, index=0
            )
        else:
            if "dense" in tensorizers:
                return ScriptPyTextModuleWithDense(
                    model=traced_model,
                    output_layer=self.output_layer.torchscript_predictions(),
                    tensorizer=script_tensorizer,
                    normalizer=tensorizers["dense"].normalizer,
                )
            else:
                return ScriptPyTextModule(
                    model=traced_model,
                    output_layer=self.output_layer.torchscript_predictions(),
                    tensorizer=script_tensorizer,
                )

    def graph_mode_quantize(
        self,
        inputs,
        data_loader,
        calibration_num_batches=64,
        qconfig_dict=None,
        force_quantize=False,
    ):
        """Quantize the model during export with graph mode quantization."""
        if force_quantize:
            trace = self.trace(inputs)
            if not qconfig_dict:
                qconfig_dict = {"": get_default_qconfig("fbgemm")}
            prepare_m = prepare_jit(trace, qconfig_dict, inplace=False)
            prepare_m.eval()
            with torch.no_grad():
                for i, (_, batch) in enumerate(data_loader):
                    print("Running calibration with batch {}".format(i))
                    input_data = self.onnx_trace_input(batch)
                    prepare_m(*input_data)
                    if i == calibration_num_batches - 1:
                        break
            trace = convert_jit(prepare_m, inplace=True)
        else:
            super().quantize()
            trace = self.trace(inputs)

        return trace


class SELFIE(RoBERTa):
    class Config(RoBERTa.Config):
        use_selfie: bool = True

    def forward(
        self, encoder_inputs: Tuple[torch.Tensor, ...], *args
    ) -> List[torch.Tensor]:
        if self.encoder.output_encoded_layers:
            # the encoder returns (encoded_layers, pooled_output); keep only the
            # pooled output
            representation = self.encoder(encoder_inputs, args[0])[1]
        else:
            representation = self.encoder(encoder_inputs, args[0])[0]
        return self.decoder(representation)


class RoBERTaRegression(NewBertRegressionModel):
    class Config(NewBertRegressionModel.Config):
        class RegressionModelInput(ConfigBase):
            tokens: RoBERTaTensorizer.Config = RoBERTaTensorizer.Config()
            labels: NumericLabelTensorizer.Config = NumericLabelTensorizer.Config()

        inputs: RegressionModelInput = RegressionModelInput()
        encoder: RoBERTaEncoderBase.Config = RoBERTaEncoderJit.Config()

    def torchscriptify(self, tensorizers, traced_model):
        """Using the traced model, create a ScriptModule which has a nicer API that
        includes generating tensors from simple data types, and returns classified
        values according to the output layer (eg. as a dict mapping class name to score)
        """
        script_tensorizer = tensorizers["tokens"].torchscriptify()
        return ScriptPyTextModule(
            model=traced_model,
            output_layer=self.output_layer.torchscript_predictions(),
            tensorizer=script_tensorizer,
        )


class RoBERTaWordTaggingModel(BaseModel):
    """
    Single Sentence Token-level Classification Model using RoBERTa.
    """

    class Config(BaseModel.Config):
        class WordTaggingInputConfig(ConfigBase):
            tokens: RoBERTaTokenLevelTensorizer.Config = (
                RoBERTaTokenLevelTensorizer.Config()
            )

        inputs: WordTaggingInputConfig = WordTaggingInputConfig()
        encoder: RoBERTaEncoderBase.Config = RoBERTaEncoderJit.Config()
        decoder: MLPDecoder.Config = MLPDecoder.Config()
        output_layer: WordTaggingOutputLayer.Config = WordTaggingOutputLayer.Config()

    @classmethod
    def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        label_vocab = tensorizers["tokens"].labels_vocab
        vocab = tensorizers["tokens"].vocab

        encoder = create_module(
            config.encoder,
            output_encoded_layers=True,
            padding_idx=vocab.get_pad_index(),
            vocab_size=len(vocab),
        )
        decoder = create_module(
            config.decoder, in_dim=encoder.representation_dim, out_dim=len(label_vocab)
        )
        output_layer = create_module(config.output_layer, labels=label_vocab)
        return cls(encoder, decoder, output_layer)

    def __init__(self, encoder, decoder, output_layer, stage=Stage.TRAIN) -> None:
        super().__init__(stage=stage)
        self.encoder = encoder
        self.decoder = decoder
        self.module_list = [encoder, decoder]
        self.output_layer = output_layer
        self.stage = stage
        log_class_usage(__class__)

    def arrange_model_inputs(self, tensor_dict):
        tokens, pad_mask, segment_labels, positions, _ = tensor_dict["tokens"]
        model_inputs = (tokens, pad_mask, segment_labels, positions)
        return (model_inputs,)

    def arrange_targets(self, tensor_dict):
        _, _, _, _, labels = tensor_dict["tokens"]
        return labels

    def forward(self, encoder_inputs: Tuple[torch.Tensor, ...], *args) -> torch.Tensor:
        # The encoder outputs a list of representations for each token where
        # every element of the list corresponds to a layer in the transformer.
        # We extract and pass the representations associated with the last layer
        # of the transformer.
        representation = self.encoder(encoder_inputs)[0][-1]
        return self.decoder(representation, *args)


class RoBERTaR3F(RoBERTa, R3FPyTextMixin):
    class Config(RoBERTa.Config):
        r3f_options: R3FConfigOptions = R3FConfigOptions()

    def get_embedding_module(self, *args, **kwargs):
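        # R3F perturbs the output of this embedding module with noise during
        # training (see R3FPyTextMixin); for RoBERTa that is the transformer's
        # token embedding.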
        return self.encoder.encoder.transformer.token_embedding

    def original_forward(self, *args, **kwargs):
        return RoBERTa.forward(self, *args, **kwargs)

    def get_sample_size(self, model_inputs, targets):
        return targets.size(0)

    def __init__(
        self, encoder, decoder, output_layer, r3f_options, stage=Stage.TRAIN
    ) -> None:
        RoBERTa.__init__(self, encoder, decoder, output_layer, stage=stage)
        R3FPyTextMixin.__init__(self, r3f_options)

    def forward(self, *args, use_r3f: bool = False, **kwargs):
        return R3FPyTextMixin.forward(self, *args, use_r3f=use_r3f, **kwargs)

    @classmethod
    def train_batch(cls, model, batch, state=None):
        return R3FPyTextMixin.train_batch(model=model, batch=batch, state=state)
